import sys
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import sys, os, ast, re
import cv2
import numpy as np
sys.path.append('./')
from utils.result_preprocess import GUI_ODYSSEY_PREPROCESS
from utils.utils_odyssey.qwen_generation_utils import make_context, decode_tokens
from utils.utils_odyssey.modeling_qwen import QWenLMHeadModel
from utils.utils_odyssey.configuration_qwen import QWenConfig
from utils.utils_odyssey.tokenization_qwen import QWenTokenizer

class GUI_Odyssey_Agent:
    """Wrapper around an OdysseyAgent (Qwen-VL based) checkpoint used for GUI
    action prediction and for perturbation probes of the input screenshot.

    Ground-truth click coordinates (``obs['label']``) and annotated boxes
    (``obs['bbox']``) are handled on a 0-1000 normalised scale and converted
    to pixels with the screenshot size before any image edits.
    """

    # Root directory under which probed screenshots are written.
    OUTPUT_ROOT = '/Agent_ScanKit/datasets/json'
    # Android widget classes treated as click targets in AndroidControl trees.
    _CLICKABLE_CLASSES = (
        'android.widget.ImageButton',
        'android.widget.TextView',
        'android.widget.ImageView',
    )

    def __init__(self, device, accelerator, cache_dir='~/.cache', dropout=0.5, policy_lm=None):
        """Store runtime handles; the model is loaded separately by ``_load_model``.

        ``cache_dir`` and ``dropout`` are accepted for interface compatibility
        but are not used by this class.
        """
        self.model = None
        self.policy_lm = policy_lm
        self.device = device
        self.accelerator = accelerator
        self.res_pre_process = self._res_pre_process()
        # merge_weight() is a one-off preprocessing step; it is not run here.
        self.max_window_size = 6144
        self.chat_format = 'chatml'

    def _load_model(self):
        """Load the causal LM and tokenizer from ``self.policy_lm`` and return the model.

        The model is loaded in bfloat16 and switched to eval mode; the
        tokenizer truncates from the left and pads with its ``eod_id`` token.
        """
        self.model = AutoModelForCausalLM.from_pretrained(self.policy_lm, trust_remote_code=True, torch_dtype=torch.bfloat16).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.policy_lm, trust_remote_code=True)
        self.tokenizer.truncation_side = 'left'
        self.tokenizer.pad_token_id = self.tokenizer.eod_id
        return self.model

    def _res_pre_process(self):
        """Return the response/label parser used for GUI-Odyssey outputs."""
        return GUI_ODYSSEY_PREPROCESS()

    def merge_weight(self, qwenvl_path='/data4/models/OdysseyAgent-random',
                     backpack_dir="/Agent_ScanKit/utils/utils_odyssey"):
        """
        This function is modified from the `merge_weight.py` file of Odyssey repo. It serves as a pre-processing of the model.

        Copies every tensor of the checkpoint at ``qwenvl_path`` whose name also
        exists in a fresh ``QWenLMHeadModel`` built from
        ``<backpack_dir>/config.json``, then saves the tokenizer and the merged
        model into ``backpack_dir``.
        """
        bp_cfg = os.path.join(backpack_dir, 'config.json')

        tokenizer = AutoTokenizer.from_pretrained(qwenvl_path, trust_remote_code=True)
        tokenizer.save_pretrained(backpack_dir)

        qwen_model = AutoModelForCausalLM.from_pretrained(qwenvl_path, device_map=None, trust_remote_code=True)
        with open(bp_cfg, encoding='utf-8') as cfg_file:
            cfg = QWenConfig(**json.load(cfg_file))
        new_qwen_model = QWenLMHeadModel(cfg)

        print("start merging weight...")
        qwen_dict = qwen_model.state_dict()
        odyssey_agent_dict = new_qwen_model.state_dict()
        # Only shared parameter names are copied; the rest keep their fresh init.
        for name, tensor in qwen_dict.items():
            if name in odyssey_agent_dict:
                odyssey_agent_dict[name] = tensor
        new_qwen_model.load_state_dict(odyssey_agent_dict)
        print("saving...")
        new_qwen_model.save_pretrained(backpack_dir)

    def get_action(self, obs, args):
        """Run one probe on ``obs`` and return the model's prediction.

        ``args.probing_method`` selects the perturbation applied to the first
        screenshot in ``obs['images']``:

        * ``'visual_mask'`` / ``'visual_edit'``: occlude the target (see
          ``visual_mask``); returns the model response.
        * ``'zoom'``: zoom into the target quadrant (see ``zoom_in``); returns
          ``(response, label)`` where ``label['label']`` is rewritten as a
          ``"CLICK: (x, y)"`` string in the zoomed frame.
        * anything else: query the model with the unmodified screenshot.

        Note: ``obs['question']`` is rewritten in place by ``_get_action``.
        """
        screenshot = obs.get('images')[0]
        method = args.probing_method
        if method in ('visual_mask', 'visual_edit'):
            return self._get_action(obs, self.visual_mask(screenshot, obs, args))
        if method == 'zoom':
            image_path, label = self.zoom_in(screenshot, obs)
            output = self._get_action(obs, image_path)
            label['label'] = f"CLICK: ({int(label['label'][0])}, {int(label['label'][1])})"
            return output, label
        # Baseline: no perturbation, the original screenshot is used as-is.
        return self._get_action(obs, screenshot)

    @staticmethod
    def _substitute_image_path(question, original_image, image_path):
        """Return ``question`` with ``original_image`` and the history-image tag pointing at ``image_path``.

        Every occurrence of ``original_image`` is replaced, and on each line the
        span from ``<img>image-history: `` to the last ``</img>`` is rewritten
        to reference ``image_path``.
        """
        question = question.replace(original_image, image_path)
        # A function replacement inserts the path literally; building a template
        # string would interpret backslashes / digits in the path as escapes.
        return re.sub(r'(<img>image-history: ).*(</img>)',
                      lambda match: match.group(1) + image_path + match.group(2),
                      question)

    def _get_action(self, obs, image_path):
        """Query the model for the next action on ``image_path``.

        Rewrites ``obs['question']`` in place so it references ``image_path``,
        builds a chatml prompt with ``make_context`` and greedily decodes
        (no sampling, one beam) at most 30 new tokens. Requires
        ``_load_model`` to have been called.

        Returns:
            The decoded response produced by ``decode_tokens``.
        """
        obs['question'] = self._substitute_image_path(obs['question'], obs.get('images')[0], image_path)
        raw_texts, _ = make_context(self.tokenizer, obs['question'], system="You are a helpful assistant.", max_window_size=self.max_window_size, chat_format=self.chat_format)

        input_ids = self.tokenizer(raw_texts, return_tensors='pt', padding='longest')
        attention_mask = input_ids.attention_mask.to(self.model.device)
        input_ids = input_ids.input_ids.to(self.model.device)

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=False,
                num_beams=1,
                length_penalty=1,
                num_return_sequences=1,
                use_cache=True,
                pad_token_id=self.tokenizer.eod_id,
                eos_token_id=self.tokenizer.eod_id,
                min_new_tokens=1,
                max_new_tokens=30,
            )
        # Strip any pad tokens so decoding starts at the real prompt.
        padding_len = input_ids[0].eq(self.tokenizer.pad_token_id).sum().item()
        response = decode_tokens(
            generated_ids[0][padding_len:],
            self.tokenizer,
            raw_text_len=len(raw_texts),
            context_length=input_ids.size(1) - padding_len,
            chat_format="chatml",
            verbose=False,
            errors='replace'
        )
        return response

    def visual_mask(self, image, obs, args):
        """Occlude the ground-truth click target in ``image`` and save the result.

        Candidate boxes come from ``_collect_bboxes``; only those containing
        the ground-truth click (parsed from ``obs['label']``) are perturbed:

        * ``args.probing_method == 'visual_mask'``: containing boxes are filled
          black; if none contain the click, a black square of half-side
          ``args.mask_object_ratio`` pixels is drawn around the click instead.
        * any other method (used for ``'visual_edit'``): containing boxes and
          the click square are inpainted with OpenCV (see ``_inpaint_regions``).

        Returns:
            Path of the saved image (written under the ``visual_mask`` output dir).
        """
        from PIL import ImageDraw, Image
        image_input = Image.open(image).convert('RGB')
        image_width, image_height = image_input.size
        bbox_data = self._collect_bboxes(obs, image_width, image_height)
        gt = self.res_pre_process.extract_action(obs['label'])
        gt = self.res_pre_process.extract_coordinates(gt)
        _, bbox_list, point = self.remove_containing_bboxes(
            bbox_list=bbox_data, gt=gt, image_size=[image_width, image_height])
        if args.probing_method == 'visual_mask':
            draw = ImageDraw.Draw(image_input)
            if bbox_list:
                for x, y, w, h in bbox_list:
                    draw.rectangle([x, y, x + w, y + h], fill="black")
            else:
                r = args.mask_object_ratio
                draw.rectangle([point[0] - r, point[1] - r, point[0] + r, point[1] + r], fill="black")
        else:
            image_input = self._inpaint_regions(image_input, bbox_list, point, args.mask_object_ratio)

        output_path = self._output_path('visual_mask', image)
        image_input.save(output_path)
        return output_path

    def _collect_bboxes(self, obs, image_width, image_height):
        """Return candidate ``[x, y, w, h]`` pixel boxes for ``obs``'s dataset.

        AndroidControl and AITZ read the file at ``obs['accessibility_trees']``;
        any other dataset uses the single normalised ``obs['bbox']``.
        """
        dataset_name = obs.get('dataset_name')
        if dataset_name == 'AndroidControl':
            return self._android_control_bboxes(obs['accessibility_trees'], image_width, image_height)
        if dataset_name == 'AITZ':
            return self._aitz_bboxes(obs['accessibility_trees'], obs.get('images')[0])
        return self._normalized_bbox(obs.get('bbox'), image_width, image_height)

    @classmethod
    def _android_control_bboxes(cls, tree_path, image_width, image_height):
        """Read clickable widget boxes from an AndroidControl JSON-lines dump.

        Keeps entries whose ``class_name`` is in ``_CLICKABLE_CLASSES``, whose
        ``is_clickable`` flag is truthy and whose ``bbox_pixels`` box lies
        inside the image. Lines that fail to parse are skipped.

        Returns:
            A list of ``[x, y, w, h]`` pixel boxes.
        """
        bbox_data = []
        with open(tree_path, "r", encoding="utf-8") as file:
            for line in file:
                try:
                    obj = json.loads(line)
                    bbox = obj.get("bbox_pixels", None)
                    class_name = obj.get("class_name", None)
                    if not (bbox and class_name in cls._CLICKABLE_CLASSES and obj.get("is_clickable")):
                        continue
                    x_min, y_min, x_max, y_max = bbox["x_min"], bbox["y_min"], bbox["x_max"], bbox["y_max"]
                    inside_image = (0 <= x_min < x_max <= image_width
                                    and 0 <= y_min < y_max <= image_height)
                except (ValueError, KeyError, TypeError, AttributeError):
                    # Best-effort parsing: skip non-JSON lines and malformed entries.
                    continue
                if inside_image:
                    bbox_data.append([x_min, y_min, x_max - x_min, y_max - y_min])
        return bbox_data

    @staticmethod
    def _aitz_bboxes(tree_path, image):
        """Return the UI boxes recorded for ``image`` in an AITZ annotation file.

        The file is a JSON list of records; a record matches when its
        ``image_path`` is a substring of ``image`` (the last match wins). Its
        ``ui_positions`` string is parsed with ``ast.literal_eval`` and each
        4-tuple has its first/second and third/fourth entries swapped, i.e. it
        is read as ``(y, x, h, w)`` and returned as ``[x, y, w, h]``.
        Returns ``[]`` when no record matches.
        """
        with open(tree_path, "r", encoding="utf-8") as file:
            records = json.load(file)
        positions = []
        for record in records:
            if record['image_path'] in image:
                positions = ast.literal_eval(record['ui_positions'])
        return [[x, y, w, h] for (y, x, h, w) in positions]

    @staticmethod
    def _normalized_bbox(bbox, image_width, image_height):
        """Convert a 0-1000 ``[x_min, y_min, x_max, y_max]`` box to ``[[x, y, w, h]]`` pixels.

        Returns ``[]`` when ``bbox`` is ``None``.
        """
        if bbox is None:
            return []
        x_min = bbox[0] / 1000 * image_width
        y_min = bbox[1] / 1000 * image_height
        x_max = bbox[2] / 1000 * image_width
        y_max = bbox[3] / 1000 * image_height
        return [[x_min, y_min, x_max - x_min, y_max - y_min]]

    @staticmethod
    def _inpaint_regions(image_input, bbox_list, point, radius):
        """Inpaint each ``[x, y, w, h]`` box and a square around ``point``.

        Each box is inpainted in turn, then a square of half-side ``radius``
        pixels centred on ``point`` (clipped at the image's top/left edge) is
        inpainted too, using OpenCV's TELEA method with radius 3.

        Returns:
            A new RGB PIL image.
        """
        from PIL import Image
        canvas = cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)
        mask_shape = canvas.shape[:2]
        for x, y, w, h in bbox_list:
            mask = np.zeros(mask_shape, dtype=np.uint8)
            # Boxes are (x, y, w, h): rows span y..y+h, columns x..x+w.
            mask[max(int(y), 0):int(y + h), max(int(x), 0):int(x + w)] = 255
            canvas = cv2.inpaint(canvas, mask, 3, cv2.INPAINT_TELEA)

        x, y = point
        mask = np.zeros(mask_shape, dtype=np.uint8)
        # Clamp at 0: a negative start would wrap the slice and yield an empty mask.
        mask[max(int(y - radius), 0):int(y + radius), max(int(x - radius), 0):int(x + radius)] = 255
        canvas = cv2.inpaint(canvas, mask, 3, cv2.INPAINT_TELEA)
        # OpenCV works in BGR; convert back so PIL stores the true colours.
        return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    def _output_path(self, probe_dir, image):
        """Return (and create the directory of) the save path for a probed ``image``.

        The file name joins the last five ``/``-separated components of
        ``image`` with underscores.
        """
        out_dir = os.path.join(self.OUTPUT_ROOT, probe_dir, 'gui_odyssey_images')
        os.makedirs(out_dir, exist_ok=True)
        return os.path.join(out_dir, '_'.join(image.split('/')[-5:]))

    def remove_containing_bboxes(self, bbox_list, gt, image_size):
        """Split ``bbox_list`` by whether each box contains the ground-truth click.

        Args:
            bbox_list: ``[x, y, w, h]`` pixel boxes (``None`` is treated as empty).
            gt: click ``(x, y)`` on the 0-1000 normalised scale.
            image_size: ``[width, height]`` in pixels.

        Returns:
            ``(boxes_without_click, boxes_with_click, (click_x, click_y))`` with
            the click converted to pixels. Points on a box edge count as inside.
        """
        click_x = gt[0] / 1000 * image_size[0]
        click_y = gt[1] / 1000 * image_size[1]
        out_bbox_list = []
        in_bbox_list = []
        for bbox in (bbox_list if bbox_list is not None else []):
            x, y, w, h = bbox
            if x <= click_x <= x + w and y <= click_y <= y + h:
                in_bbox_list.append(bbox)
            else:
                out_bbox_list.append(bbox)
        return out_bbox_list, in_bbox_list, (click_x, click_y)

    def zoom_in(self, image, obs):
        """Zoom into the screen quadrant containing the ground-truth click.

        The quadrant is cropped, resized back to the original size with
        LANCZOS resampling and saved under the ``zoom`` output dir.

        Returns:
            ``(saved_path, {"label": [x, y], "bbox": box_or_None})`` where the
            click and (if ``obs['bbox']`` is given) the box corners are
            re-expressed on the 0-1000 scale of the zoomed image.

        Raises:
            ValueError: if ``obs`` has no usable ``'label'`` entry.
        """
        from PIL import Image
        pil_image = Image.open(image).convert('RGB')
        try:
            content = obs['label']
        except (IndexError, KeyError, TypeError) as err:
            raise ValueError("Invalid message format in obs") from err

        ground_truth = self.res_pre_process.extract_action(content)
        bbox = obs.get("bbox")
        w, h = pil_image.size

        click_x, click_y = self.res_pre_process.extract_coordinates(ground_truth)
        click_x, click_y = click_x / 1000 * w, click_y / 1000 * h
        region = self._quadrant_containing(click_x, click_y, w, h)
        zoomed_image = pil_image.crop(region).resize((w, h), Image.LANCZOS)

        new_click_x, new_click_y = self._to_zoomed_coords(click_x, click_y, region, w, h)
        norm_click = [new_click_x / w * 1000, new_click_y / h * 1000]

        new_bbox = None
        if bbox is not None:
            x_min, y_min = self._to_zoomed_coords(bbox[0] / 1000 * w, bbox[1] / 1000 * h, region, w, h)
            x_max, y_max = self._to_zoomed_coords(bbox[2] / 1000 * w, bbox[3] / 1000 * h, region, w, h)
            new_bbox = [x_min / w * 1000, y_min / h * 1000, x_max / w * 1000, y_max / h * 1000]

        output_path = self._output_path('zoom', image)
        zoomed_image.save(output_path)
        return output_path, {"label": norm_click, "bbox": new_bbox}

    @staticmethod
    def _quadrant_containing(x, y, width, height):
        """Return the ``(left, top, right, bottom)`` quadrant of a ``width``x``height`` image containing ``(x, y)``.

        Points on the midlines belong to the right/bottom quadrants.
        """
        mid_x, mid_y = width // 2, height // 2
        left_half = x < mid_x
        top_half = y < mid_y
        return (
            0 if left_half else mid_x,
            0 if top_half else mid_y,
            mid_x if left_half else width,
            mid_y if top_half else height,
        )

    @staticmethod
    def _to_zoomed_coords(x, y, region, width, height):
        """Map pixel ``(x, y)`` inside ``region`` to the zoomed ``width``x``height`` frame (truncated to int)."""
        scale_x = width / (region[2] - region[0])
        scale_y = height / (region[3] - region[1])
        return int((x - region[0]) * scale_x), int((y - region[1]) * scale_y)

    




        
        